9c45f9
@@ -18,7 +18,9 @@
 package org.apache.hadoop.hive.ql.optimizer.calcite.reloperators;
 
 import java.util.List;
+import java.util.Set;
 
+import org.apache.calcite.linq4j.Ord;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptCost;
 import org.apache.calcite.plan.RelOptPlanner;
@@ -29,10 +31,16 @@
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.RelFactories.AggregateFactory;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.IntList;
 import org.apache.hadoop.hive.ql.optimizer.calcite.TraitsUtil;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
 
 public class HiveAggregate extends Aggregate implements HiveRelNode {
 
@@ -81,6 +89,56 @@
public boolean isBucketedInput() {
             containsAll(groupSet.asList());
   }
 
+  @Override
+  protected RelDataType deriveRowType() {
+    return deriveRowType(getCluster().getTypeFactory(), getInput().getRowType(),
+        indicator, groupSet, groupSets, aggCalls);
+  }
+
+  public static RelDataType deriveRowType(RelDataTypeFactory typeFactory,
+      final RelDataType inputRowType, boolean indicator,
+      ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
+      final List<AggregateCall> aggCalls) {
+    final IntList groupList = groupSet.toList();
+    assert groupList.size() == groupSet.cardinality();
+    final RelDataTypeFactory.FieldInfoBuilder builder = typeFactory.builder();
+    final List<RelDataTypeField> fieldList = inputRowType.getFieldList();
+    final Set<String> containedNames = Sets.newHashSet();
+    for (int groupKey : groupList) {
+      containedNames.add(fieldList.get(groupKey).getName());
+      builder.add(fieldList.get(groupKey));
+    }
+    if (indicator) {
+      for (int groupKey : groupList) {
+        final RelDataType booleanType =
+            typeFactory.createTypeWithNullability(
+                typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
+        String name = "i$" + fieldList.get(groupKey).getName();
+        int i = 0;
+        while (containedNames.contains(name)) {
+          name += "_" + i++;
+        }
+        containedNames.add(name);
+        builder.add(name, booleanType);
+      }
+    }
+    for (Ord<AggregateCall> aggCall : Ord.zip(aggCalls)) {
+      String name;
+      if (aggCall.e.name != null) {
+        name = aggCall.e.name;
+      } else {
+        name = "$f" + (groupList.size() + aggCall.i);
+      }
+      int i = 0;
+      while (containedNames.contains(name)) {
+        name += "_" + i++;
+      }
+      containedNames.add(name);
+      builder.add(name, aggCall.e.type);
+    }
+    return builder.build();
+  }
+
   private static class HiveAggRelFactory implements AggregateFactory {
 
     @Override
